import os
import re
from tqdm import tqdm
from openai import OpenAI
import json
import time
import random
import torch
import ast
import numpy as np
from abc import ABC, abstractmethod
from torch.nn import functional as F
from collections import defaultdict
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

from utils.util import read_json, write_json, read_txt
from utils.token_count_decorator import token_count_decorator
from utils.union_find import UnionFind

class Annotation(ABC):
    """Base class for annotating extracted targets through the OpenAI Batch API.

    Subclasses provide the targets (``target_extraction``), the prompt template
    (``self.prompt``, in which ``---TARGET---`` is replaced by each target) and
    the reply validator (``judge_process``).

    Attributes:
        data: Content loaded from ``input_data_path`` with ``read_json``.
        prompt (str): Prompt template; empty until a subclass sets it.
        batch_input_path (str): JSONL file the batch requests are appended to.
        batch_output_path (str): JSONL file the batch results are saved to.
        concurrent (int): Maximum number of requests submitted in one batch.
    """

    def __init__(self, input_data_path):
        self.data = read_json(input_data_path)
        self.prompt = ""
        self.batch_input_path = "dsl_design/data/temp_batch/annotation_batch_input.jsonl"
        self.batch_output_path = "dsl_design/data/temp_batch/annotation_batch_output.jsonl"
        self.concurrent = 20000

    @abstractmethod
    def annotate(self):
        """Annotate every extracted target, one batch job per chunk of targets.

        Abstract so that subclasses must override it, but it carries the shared
        implementation that overrides reach through ``super().annotate()``.

        Returns:
            tuple[dict, dict]: ``(target_to_label, label_to_target)``. Targets
            whose reply had no valid label are stored under ``"NONE"``; targets
            of a batch that returned no results are absent from both mappings.

        Raises:
            RuntimeError: If a batch job could not be created.
        """
        target_to_label = {}
        label_to_target = {}
        annotated_len = 0
        target_list = self.target_extraction()
        total = len(target_list)

        # An empty target list would make the batch size 0, which is an invalid
        # range() step, so the batching loop only runs when there is work.
        if target_list:
            self.concurrent = min(self.concurrent, total)
            for start in tqdm(range(0, total, self.concurrent)):
                targets_batch = target_list[start:start + self.concurrent]
                self.empty_jsonl_contents()
                for i, name in enumerate(targets_batch):
                    self.gpt_batch_store(self.prompt.replace("---TARGET---", name), index=str(i))
                print("Batch stored")
                batch_obj = self.gpt_batch_call()
                if batch_obj is None:
                    raise RuntimeError("Failed to create the OpenAI batch job after all retries")
                print("Batch called, waiting for results... batch id: ", batch_obj.id)
                results = self.gpt_batch_result(batch_obj.id)
                print("Results received")
                for result in results:
                    # custom_id is the position of the target inside this batch.
                    target = targets_batch[int(result["custom_id"])]
                    label = result["text"]
                    target_to_label[target] = label
                    label_to_target.setdefault(label, []).append(target)
                    if label != "NONE":
                        annotated_len += 1
                time.sleep(8)
        print("Annotation is finished")

        failed = total - annotated_len
        fail_percent = failed / total * 100 if total else 0.0
        print("num of target: ", total)
        print("Annotated: ", annotated_len)
        print("Fail: ", failed)
        print("Fail %: ", format(fail_percent, ".2f"))

        return target_to_label, label_to_target

    @abstractmethod
    def target_extraction(self):
        """Return the list of target strings to annotate."""
        pass

    @abstractmethod
    def judge_process(self, reply):
        """Return the label found in one stripped reply line, or a falsy value."""
        pass

    @token_count_decorator
    def gpt_batch_store(self, content, index):
        """Append one chat-completion request to the batch input file.

        Args:
            content (str): Text of the user message.
            index (str): Value stored as the request's ``custom_id``.
        """
        request = {
            "custom_id": index,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini",
                "messages": [
                    {"role": "system", "content": "You are an experimental scientist in the fields of biology and chemistry."},
                    {"role": "user", "content": content}
                ],
                "max_tokens": 1000
            }
        }
        with open(self.batch_input_path, 'a') as file:
            file.write(json.dumps(request) + '\n')

    def gpt_batch_call(self):
        """Upload the batch input file and create a batch job.

        Up to 10 attempts are made, 3 seconds apart.

        Returns:
            The object returned by ``client.batches.create``, or ``None`` when
            every attempt failed.
        """
        retries = 10
        while retries > 0:
            try:
                client = OpenAI()
                # The context manager closes the upload handle on every attempt.
                with open(self.batch_input_path, "rb") as batch_file:
                    batch_input_file = client.files.create(
                        file=batch_file,
                        purpose="batch"
                    )
                return client.batches.create(
                    input_file_id=batch_input_file.id,
                    endpoint="/v1/chat/completions",
                    completion_window="24h",
                )
            except Exception as e:
                # Retry boundary: any failure is reported and retried.
                retries -= 1
                print(f"Batch creation failed: {e} ({retries} retries left)")
                time.sleep(3)

        return None

    def gpt_batch_result(self, batch_id):
        """Poll a batch job until it ends and return its processed results.

        Args:
            batch_id: Identifier of the batch job to poll.

        Returns:
            list[dict]: Output of ``process_results`` for a completed batch, or
            an empty list when the batch failed, expired, was cancelled,
            produced no output file, or its status could not be retrieved 10
            times in a row.
        """
        client = OpenAI()
        max_failures = 10
        failures = 0
        while True:
            try:
                batch = client.batches.retrieve(batch_id)
            except Exception as e:
                failures += 1
                print(f"Failed to retrieve batch {batch_id}: {e}")
                if failures >= max_failures:
                    print(f"Giving up on batch {batch_id} after {max_failures} failed status checks")
                    return []
                time.sleep(3)
                continue
            failures = 0

            if batch.status == "completed":
                if batch.output_file_id is None:
                    print(f"Batch {batch_id} completed without an output file")
                    return []
                result = client.files.content(batch.output_file_id).content
                with open(self.batch_output_path, "wb") as file:
                    file.write(result)
                results = []
                with open(self.batch_output_path, "r", encoding="utf-8") as file:
                    for line in file:
                        line = line.strip()
                        if line:
                            results.append(json.loads(line))
                return self.process_results(results)
            if batch.status in ["failed", "expired", "cancelled", "cancelling"]:
                print(f"Batch {batch.status}")
                return []
            time.sleep(3)

    def process_results(self, results):
        """Reduce raw batch output entries to one label per request.

        Each reply is scanned line by line and the first line for which
        ``judge_process`` returns a truthy label is used. Entries with no such
        line, or without a readable message content, are labelled ``"NONE"``.

        Args:
            results (list[dict]): Parsed lines of the batch output file.

        Returns:
            list[dict]: One ``{"custom_id": ..., "text": label}`` per entry.
        """
        return_results = []
        for res in results:
            try:
                content = res["response"]["body"]["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                # The entry does not have the expected response structure.
                content = None
            label = None
            for line in (content or "").split("\n"):
                label = self.judge_process(line.strip())
                if label:
                    break
            return_results.append({
                "custom_id": res["custom_id"],
                "text": label if label else "NONE"
            })
        return return_results

    def empty_jsonl_contents(self):
        """Truncate the batch input file if it already exists."""
        if os.path.exists(self.batch_input_path):
            with open(self.batch_input_path, 'w') as file:
                file.write('')

class ComponentAnnotation(Annotation):
    """Annotate a superclass onto the flow units of recognized data.

    Attributes:
        input_data_path (str): Path of the recognized data file.
        annotated_store_path (str): Path the annotated data is written to,
            default="". If empty, the annotated file overwrites the input file.
    """

    def __init__(self, input_data_path, annotated_store_path=""):
        super().__init__(input_data_path)
        self.annotated_store_path = annotated_store_path or input_data_path
        self.prompt = read_txt("dsl_design/data/prompt/component_superclass_annotation_2.txt")

    def annotate(self, store_path):
        """Annotate all flow units and persist the results.

        Args:
            store_path (str): Path the superclass -> flow-unit-names mapping is
                written to.
        """
        flowunit_to_superclass, superclass_to_flowunit = super().annotate()
        write_json(store_path, superclass_to_flowunit)
        self.__merge_recognized_superclass(flowunit_to_superclass)

    def target_extraction(self):
        """Return the distinct non-empty flow unit names, sorted.

        Names are gathered from ``input_flow_units`` and ``output_flow_units``
        of every sentence's ``recognized`` entry. Sorting keeps the order, and
        therefore the batch contents, the same from run to run.
        """
        flowunit_names = set()
        for protocol in self.data:
            for sentence in protocol:
                recognized = sentence["recognized"]
                for flow_units in ("input_flow_units", "output_flow_units"):
                    for unit in recognized.get(flow_units, []):
                        name = unit.get("Name", "")
                        if name != "":
                            flowunit_names.add(name)
        return sorted(flowunit_names)

    def judge_process(self, reply):
        """Extract a label from one reply line of the form ``<phase>, <type>``.

        Args:
            reply (str): One stripped line of the model reply.

        Returns:
            The phase when it is a known phase, otherwise the type when it is a
            known type, otherwise ``None`` (also when the line does not consist
            of exactly two comma-separated parts).
        """
        phases = {"Gas", "Liquid", "Solid", "Semi-Solid", "Mixture"}
        types = {"Chemical Compound", "Biological Material", "Reagent", "Physical Object", "File/Data"}
        try:
            p, t = map(str.strip, reply.split(","))
        except (ValueError, AttributeError):
            # Not exactly two parts, or the reply is not a string.
            return None

        if p in phases:
            return p
        if t in types:
            return t
        return None

    def __merge_recognized_superclass(self, superclass_mapping: dict):
        """Write each flow unit's label into ``Superclass`` and save the data.

        Units whose name has no entry in ``superclass_mapping`` are left
        untouched. ``self.data`` is modified in place and then written to
        ``self.annotated_store_path``.

        Args:
            superclass_mapping (dict): Flow unit name -> superclass label.
        """
        for experiment in self.data:
            for step in experiment:
                for flow_units in ("input_flow_units", "output_flow_units"):
                    for unit in step['recognized'].get(flow_units, []):
                        try:
                            unit['Superclass'] = superclass_mapping[unit["Name"]]
                        except (KeyError, TypeError):
                            # Unit without a usable name, or name not annotated.
                            continue
        write_json(self.annotated_store_path, self.data)

class OperationAnnotation(Annotation):
    """
    Annotate superclass to operations in extracted data or recognized data

    Attributes:
        input_data_path (str): The path of input data, either be extracted file or recognized file
        annotated_store_path (str): The path of output data, default="". If annotated_store_path="", the annotated file will overwrite the input file
    """
    def __init__(self, input_data_path, annotated_store_path=""):
        super().__init__(input_data_path)
        self.annotated_store_path = annotated_store_path or input_data_path
        self.prompt = read_txt("dsl_design/data/prompt/operation_superclass_annotation_2.txt")
        self.lemmatizer = WordNetLemmatizer()
        self.__preprocess()

    def annotate(self):
        """Annotate every distinct opcode and persist mappings and data."""
        opcode_to_superclass, superclass_to_opcode = super().annotate()
        write_json("dsl_design/data/demo/opcode_to_superclass.json", opcode_to_superclass)
        write_json("dsl_design/data/demo/superclass_to_opcode.json", superclass_to_opcode)
        self.__merge_data(opcode_to_superclass)

    def annotate_with_context(self, store_path):
        """Annotate opcodes, giving the model example sentences as context.

        ``---TARGET---`` in the prompt is replaced by the opcode and
        ``---CONTEXT---`` by its sampled sentences joined with newlines.

        Args:
            store_path (str): Path the superclass -> opcodes mapping is written to.

        Raises:
            RuntimeError: If a batch job could not be created.
        """
        opcode_to_superclass = {}
        superclass_to_opcode = {}
        annotated_len = 0
        opcode_to_sentences = self.__target_extraction_with_context()
        opcode_list = list(opcode_to_sentences.keys())
        total = len(opcode_list)

        # An empty opcode list would make the batch size 0, which is an invalid
        # range() step, so the batching loop only runs when there is work.
        if opcode_list:
            self.concurrent = min(self.concurrent, total)
            for start in tqdm(range(0, total, self.concurrent)):
                targets_batch = opcode_list[start:start + self.concurrent]
                self.empty_jsonl_contents()
                for i, name in enumerate(targets_batch):
                    context = "\n".join(opcode_to_sentences[name])
                    self.gpt_batch_store(self.prompt.replace("---TARGET---", name).replace("---CONTEXT---", context), index=str(i))
                print("Batch stored")
                batch_obj = self.gpt_batch_call()
                if batch_obj is None:
                    raise RuntimeError("Failed to create the OpenAI batch job after all retries")
                print("Batch called, waiting for results...")
                results = self.gpt_batch_result(batch_obj.id)
                print("Results received")
                for result in results:
                    # custom_id is the position of the opcode inside this batch.
                    opcode = targets_batch[int(result["custom_id"])]
                    label = result["text"]
                    opcode_to_superclass[opcode] = label
                    superclass_to_opcode.setdefault(label, []).append(opcode)
                    if label != "NONE":
                        annotated_len += 1
                time.sleep(8)
        print("Annotation is finished")

        failed = total - annotated_len
        fail_percent = failed / total * 100 if total else 0.0
        print("num of target: ", total)
        print("Annotated: ", annotated_len)
        print("Fail: ", failed)
        print("Fail %: ", format(fail_percent, ".2f"))

        write_json(store_path, superclass_to_opcode)
        self.__merge_data(opcode_to_superclass)

    def target_extraction(self):
        """Return the distinct opcodes of all sentences, sorted."""
        return sorted({sentence["opcode"] for protocol in self.data for sentence in protocol})

    def __target_extraction_with_context(self):
        """Map each opcode to at most two of the sentences that use it.

        Opcodes with a single sentence keep it; otherwise two sentences are
        drawn with ``random.sample``.
        """
        opcode_to_sentences = {}
        for protocol in self.data:
            for sentence in protocol:
                opcode_to_sentences.setdefault(sentence["opcode"], []).append(sentence["sentence"])
        return {
            opcode: sentences if len(sentences) == 1 else random.sample(sentences, 2)
            for opcode, sentences in opcode_to_sentences.items()
        }

    def count_opcode(self):
        """Return opcode -> number of occurrences, most frequent first."""
        count_dict = defaultdict(int)
        for protocol in self.data:
            for sentence in protocol:
                count_dict[sentence["opcode"]] += 1

        return dict(sorted(count_dict.items(), key=lambda item: item[1], reverse=True))

    def judge_process(self, reply):
        """Return ``reply`` when it is one of the known operation classes, else ``None``."""
        operations = [
            "Transfer Operations",
            "Transformation Operations",
            "Modification Operations",
            "Synthesis and Generation Operations",
            "Detection and Measurement Operations",
            "Time Control Operations",
            "Material Generation Operations",
            "Data Operations"
        ]
        return reply if reply in operations else None

    def __preprocess(self):
        """Normalize every opcode in ``self.data`` in place.

        Non-ASCII-letter characters are removed, the rest is lower-cased and
        lemmatized as a verb, and the first letter is capitalized.

        Raises:
            ValueError: If an opcode contains no ASCII letters, so nothing is
                left to normalize.
        """
        for protocol in self.data:
            for sentence in protocol:
                opcode = re.sub(r'[^a-zA-Z]', '', sentence["opcode"])
                if not opcode:
                    # Fail loudly instead of ending the process with a success
                    # status on malformed input.
                    raise ValueError(
                        f"Opcode {sentence['opcode']!r} contains no ASCII letters and cannot be normalized"
                    )
                lemma_opcode = self.lemmatizer.lemmatize(opcode.lower(), pos="v")
                sentence["opcode"] = lemma_opcode[:1].upper() + lemma_opcode[1:]

    def __merge_data(self, superclass_mapping: dict):
        """Attach the annotated operation to every sentence and save the data.

        Sentences whose opcode was labelled ``"NONE"`` are dropped; sentences
        whose opcode is missing from ``superclass_mapping`` are kept unchanged.
        ``self.data`` is modified in place and then written to
        ``self.annotated_store_path``.

        Args:
            superclass_mapping (dict): Opcode -> operation superclass label.
        """
        for protocol in self.data:
            kept = []
            for sentence in protocol:
                try:
                    operation = superclass_mapping[sentence["opcode"]]
                except (KeyError, TypeError):
                    # No annotation available for this sentence: keep it as is.
                    kept.append(sentence)
                    continue
                if operation == "NONE":
                    continue
                sentence["operation"] = operation
                kept.append(sentence)
            # Slice assignment keeps the same list object inside self.data.
            protocol[:] = kept

        write_json(self.annotated_store_path, self.data)

class AliasJudgement(Annotation):
    """Find entity names that refer to the same component or device.

    Candidate pairs are pre-selected by the cosine similarity of their SciBERT
    embeddings, each pair is judged ``Yes``/``No`` through the batch
    annotation, and the ``Yes`` pairs are merged into groups with a union-find.

    Attributes:
        entity (str): ``"component"`` or ``"device"``.
        same_entities_store_path (str): Path the merged groups are written to.
        candidates (list[str]): Sorted distinct entity names to compare.
        judgement_to_entities (dict): Judgement label -> list of name pairs.
        embeddings (dict): Entity name -> embedding vector.
    """

    def __init__(self, entity="component", recognized_data_path="", same_entities_store_path=""):
        if entity not in ("component", "device"):
            # Without this check self.candidates would never be set and the
            # failure would only show up later as an AttributeError.
            raise ValueError(f"Unsupported entity {entity!r}; expected 'component' or 'device'")
        super().__init__(recognized_data_path)
        self.entity = entity
        self.same_entities_store_path = same_entities_store_path
        if entity == "component":
            self.prompt = read_txt("dsl_design/data/prompt/same_component_judgement.txt")
            self.candidates = self.__component_extraction()
        else:
            self.prompt = read_txt("dsl_design/data/prompt/same_device_judgement.txt")
            self.candidates = self.__device_extraction()
        self.judgement_to_entities = {}
        self.embeddings = {}
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', clean_up_tokenization_spaces=True)
        self.model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
        self.uf = UnionFind()

    # TODO: in the real implementation, add a pre-check before calling GPT that
    # groups flow units sharing the same base form into one entity. Otherwise
    # GPT has to judge many case and singular/plural variants, which leads to
    # too many pairwise comparisons.
    def annotate(self):
        """Judge all candidate pairs, merge the aliases and store the result.

        Returns:
            dict: Representative name -> list of names judged to be the same
            entity.
        """
        _, judgement_to_entities = super().annotate()
        self.judgement_to_entities = self.__process(judgement_to_entities)
        same_entities = self.__merge_same()
        write_json(self.same_entities_store_path, same_entities)
        return same_entities

    def target_extraction(self):
        """Return the candidate pairs whose embedding similarity exceeds 0.9.

        Each pair is returned as ``str((a, b))`` where ``a`` precedes ``b`` in
        ``self.candidates``. Pairs are ordered by ``a`` first, then by ``b``.
        ``self.embeddings`` is filled as a side effect.
        """
        homo_candidates = []

        num_candidates = len(self.candidates)
        self.embeddings = {entity: self.__get_embedding(entity) for entity in tqdm(self.candidates, desc="Get embedding of candidates")}

        # With fewer than two candidates there is no pair to compare.
        if num_candidates >= 2:
            matrix = np.stack([self.embeddings[entity] for entity in self.candidates])
            total_pairs = num_candidates * (num_candidates - 1) // 2
            with tqdm(total=total_pairs, desc="Calculating cosine similarity of two entities") as progress:
                for i in range(num_candidates - 1):
                    a = self.candidates[i]
                    # One call compares candidate i with every later candidate.
                    sims = cosine_similarity(matrix[i:i + 1], matrix[i + 1:])[0]
                    for offset in np.flatnonzero(sims > 0.9):
                        homo_candidates.append(str((a, self.candidates[i + 1 + offset])))
                    progress.update(len(sims))

        print("The num of potential same entities candidate:", len(homo_candidates))

        return homo_candidates

    def judge_process(self, reply):
        """Return ``reply`` when it is exactly ``"Yes"`` or ``"No"``, else ``None``."""
        return reply if reply in ["Yes", "No"] else None

    def __get_embedding(self, text):
        """Return the L2-normalized embedding of ``text`` as a NumPy array.

        The vector is the model's last hidden state at the first token position.
        """
        inputs = self.tokenizer(text, return_tensors='pt')
        with torch.no_grad():
            outputs = self.model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding.squeeze().numpy()

    def __component_extraction(self):
        """Return the sorted distinct non-empty flow unit names in the data."""
        flowunit_names = set()
        for protocol in self.data:
            for sentence in protocol:
                recognized = sentence["recognized"]
                for flow_units in ("input_flow_units", "output_flow_units"):
                    for unit in recognized.get(flow_units, []):
                        name = unit.get("Name", "")
                        if name != "":
                            flowunit_names.add(name)
        return sorted(flowunit_names)

    def __device_extraction(self):
        """Return the sorted distinct non-empty device names in the data.

        Only sentences that have at least one input or output flow unit
        contribute devices.
        """
        device_names = set()
        for protocol in self.data:
            for sentence in protocol:
                recognized = sentence["recognized"]
                if recognized.get("input_flow_units", []) or recognized.get("output_flow_units", []):
                    for device_dict in recognized.get("devices", []):
                        if (device := device_dict.get("Name", "")) != "":
                            device_names.add(device)
        return sorted(device_names)

    def __process(self, label_to_target: dict):
        """Turn the stringified pairs of each label back into tuples."""
        return {label: [ast.literal_eval(names) for names in target] for label, target in label_to_target.items()}

    def __merge_same(self):
        """Group the entities judged ``Yes`` and pick a name for each group.

        Returns:
            dict: Representative name -> list of entity names in its group. For
            a group of two the shorter name is used (the second one on a tie);
            larger groups use ``__select_representation``.
        """
        merge_result = {}

        entity_pairs = self.judgement_to_entities.get("Yes", [])
        # Add every entity to the union-find and join the two sides of a pair.
        for entity1, entity2 in entity_pairs:
            self.uf.add(entity1)
            self.uf.add(entity2)
            self.uf.union(entity1, entity2)

        merged_entities = {}
        # Iterate over a snapshot in case find() updates uf.parent.
        for entity in list(self.uf.parent):
            root = self.uf.find(entity)
            merged_entities.setdefault(root, []).append(entity)

        for entities in merged_entities.values():
            if len(entities) == 2:
                name = entities[0] if len(entities[0]) < len(entities[1]) else entities[1]
            else:
                name = self.__select_representation(entities)
            merge_result[name] = entities

        return merge_result

    def __select_representation(self, entities: list):
        """Return the entity with the highest mean similarity to its group."""
        if len(entities) == 1:
            return entities[0]
        embeddings = [self.embeddings[component] for component in entities]
        similarity_matrix = cosine_similarity(embeddings)
        avg_similarities = similarity_matrix.mean(axis=1)
        best_index = np.argmax(avg_similarities)
        return entities[best_index]
